import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os



def _glorot_init(input_dim, output_dim):
    init_range = np.sqrt(6.0 / (input_dim + output_dim))
    initial = torch.rand(input_dim, output_dim) * 2 * init_range - init_range
#     initial = torch.randn(input_dim, output_dim)
    return nn.Parameter(initial)



class GraphConvSparse(nn.Module):
    """Single graph-convolution layer computing ``activation(adj @ (inputs @ W))``.

    ``W`` has shape ``(input_dim, output_dim)`` and is created by
    ``_glorot_init``; the layer has no bias term.

    Args:
        input_dim: number of feature columns expected in ``inputs``.
        output_dim: number of feature columns produced.
        activation: callable applied to the propagated features
            (defaults to ``F.relu``).
        **kwargs: forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, input_dim, output_dim, activation=F.relu, **kwargs):
        super().__init__(**kwargs)
        self.weight = _glorot_init(input_dim, output_dim)
        self.activation = activation

    def forward(self, inputs, adj):
        """Project ``inputs`` with the weight, propagate over ``adj``, then activate.

        Both products use ``torch.mm``, so the operands must be 2-D. ``torch.mm``
        also accepts a sparse COO matrix as its first argument; presumably
        ``adj`` may be sparse here (per the class name). Confirm with the callers.
        """
        projected = torch.mm(inputs, self.weight)
        propagated = torch.mm(adj, projected)
        return self.activation(propagated)

def dot_product_decode(A):
    """Decode embeddings into pairwise inner-product scores ``A @ A^T``.

    No sigmoid or other squashing is applied; the raw scores are returned.
    ``A`` presumably holds one node embedding per row, shape
    ``(num_nodes, dim)``. Confirm against the callers.

    Args:
        A: embedding tensor (``.t()`` requires it to be at most 2-D).

    Returns:
        ``torch.matmul(A, A.t())``, i.e. the Gram matrix of the rows of ``A``.
    """
    transposed = A.t()
    return A @ transposed
    
class GAE(nn.Module):
    """Graph auto-encoder: two stacked ``GraphConvSparse`` layers plus an inner-product decoder.

    Args:
        input_dim: number of input feature columns.
        hidden1_dim: width of the first GCN layer (``GraphConvSparse`` default
            activation).
        hidden2_dim: width of the second, identity-activated GCN layer whose
            output is the embedding ``Z``.
    """

    def __init__(self, input_dim, hidden1_dim, hidden2_dim):
        super().__init__()
        # Registration order (base_gcn, then gcn_mean) fixes parameter order and
        # the order in which the init RNG is consumed.
        self.base_gcn = GraphConvSparse(input_dim, hidden1_dim)
        # NOTE(review): a lambda attribute makes the whole module un-picklable via
        # torch.save(model); saving the state_dict is unaffected.
        self.gcn_mean = GraphConvSparse(
            hidden1_dim, hidden2_dim, activation=lambda h: h)

    def encode(self, X, adj):
        """Return the embedding ``Z = gcn_mean(base_gcn(X, adj), adj)``."""
        first_layer = self.base_gcn(X, adj)
        return self.gcn_mean(first_layer, adj)

    def forward(self, X, adj):
        """Encode ``X`` over ``adj`` and return the raw ``Z @ Z^T`` scores."""
        embedding = self.encode(X, adj)
        return dot_product_decode(embedding)





















































